/*
 * Wazuh Vulnerability scanner
 * Copyright (C) 2015, Wazuh Inc.
 * March 25, 2023.
 *
 * This program is free software; you can redistribute it
 * and/or modify it under the terms of the GNU General Public
 * License (version 2) as published by the FSF - Free Software
 * Foundation.
 */

#include "vulnerabilityScannerFacade.hpp"
#include "archiveHelper.hpp"
#include "flatbuffers/include/syscollector_deltas_generated.h"
#include "flatbuffers/include/syscollector_synchronization_generated.h"
#include "loggerHelper.h"
#include "messageBuffer_generated.h"
#include "vulnerabilityScanner.hpp"
#include "wazuh_modules/vulnerability_scanner/src/policyManager/policyManager.hpp"
#include "xzHelper.hpp"

// Default index template for the indexer connector, used when the indexer configuration has no "template_path".
constexpr auto VULNERABILITY_SCANNER_TEMPLATE = "templates/vd_states_template.json";
// Socket the alert reports are sent to (datagram client created in initAlertReportDispatcher).
constexpr auto DEFAULT_QUEUE_PATH = "queue/sockets/queue";
// Storage path and bulk size handed to the alert ReportDispatcher.
constexpr auto REPORTS_QUEUE_PATH = "queue/vd/reports";
constexpr auto REPORTS_BULK_SIZE {1};
// Bulk size and storage path handed to the DelayedEventDispatcher that retries postponed events.
constexpr auto DELAYED_EVENTS_BULK_SIZE {1};
constexpr auto DELAYED_EVENTS_QUEUE_PATH = "queue/vd/delayed";
// Minimum age, in seconds, a postponed event must reach before it is retried.
constexpr auto SLEEP_RETRY_THREADS_SEC = 60;
// Microseconds per second; converts the alerts max-EPS setting into a per-report delay.
constexpr auto MICROSEC_FACTOR {1000000};

// NOTE(review): not referenced in this file; non-const global with external linkage, possibly read
// elsewhere through an `extern` declaration — confirm before renaming or making it static/const.
int SOCKET_WAIT = 0;
// Bundled feed archive (XZ) and the intermediate TAR file produced while decompressing it.
constexpr auto COMPRESSED_DB_PATH {"tmp/vd_1.0.0_vd_4.8.0.tar.xz"};
constexpr auto DECOMPRESSED_DB_PATH {"tmp/vd_1.0.0_vd_4.8.0.tar"};
// RocksDB state database (keys used here: previous_config, disable_manager_scan, installed_content).
constexpr auto VD_STATE_QUEUE_PATH = "queue/vd/state_track";
// Keystore directory; extracted from the bundled archive only when it does not exist yet.
constexpr auto VD_KEYSTORE_PATH = "queue/keystore";
// Archive member extracted as the vulnerability feed database.
constexpr auto VD_DATABASE_PATH {"queue/vd"};
// State DB key holding the Wazuh version whose bundled content was last installed.
constexpr auto VD_DATABASE_VERSION_KEY {"installed_content"};

/**
 * @brief Current wall-clock time as whole seconds since the std::chrono::system_clock epoch.
 *
 * Used to timestamp postponed events and to decide when they may be retried.
 *
 * @return Elapsed seconds, truncated toward zero.
 */
static int64_t getSecondsFromEpoch()
{
    const auto elapsed = std::chrono::system_clock::now().time_since_epoch();
    return std::chrono::duration_cast<std::chrono::seconds>(elapsed).count();
}

/**
 * @brief Installs the vulnerability feed bundled with this build when the stored content version differs.
 *
 * Decompression is attempted only when @p databaseVersion is empty or differs from __ossec_version and
 * TMP_DIR exists. The XZ archive at COMPRESSED_DB_PATH is expanded into an intermediate TAR file, the
 * current feed (DATABASE_PATH) is removed, and VD_DATABASE_PATH (plus VD_KEYSTORE_PATH when it does not
 * exist yet) is extracted from the TAR. The intermediate TAR file is always removed, also when an
 * exception is thrown; in that case the exception is propagated to the caller.
 *
 * @param databaseVersion Content version stored in the state database (may be empty).
 * @return true if the content was extracted and no stop was requested; false otherwise.
 */
bool VulnerabilityScannerFacade::decompressDatabase(const std::string& databaseVersion)
{
    // Nothing to do when the installed content already matches this build, or when the temporary
    // directory holding the bundled archive is missing.
    const bool versionOutdated = databaseVersion.empty() || databaseVersion.compare(__ossec_version) != 0;
    if (!versionOutdated || !std::filesystem::exists(TMP_DIR))
    {
        return false;
    }

    // Check for XZ compressed file.
    if (!std::filesystem::exists(COMPRESSED_DB_PATH))
    {
        logWarn(WM_VULNSCAN_LOGTAG, "Missing database compressed file. Check DOWNLOAD_CONTENT option.");
        return false;
    }

    logInfo(WM_VULNSCAN_LOGTAG, "Starting database file decompression.");
    logDebug2(WM_VULNSCAN_LOGTAG, "Starting XZ file decompression.");

    try
    {
        // Decompress the XZ file into the intermediate TAR file.
        Utils::XzHelper(std::filesystem::path(COMPRESSED_DB_PATH), std::filesystem::path(DECOMPRESSED_DB_PATH))
            .decompress();

        // Clean up feed database.
        std::filesystem::remove_all(DATABASE_PATH);

        // Extract only the feed database directory from the TAR...
        std::vector<std::string> extractOnly;
        extractOnly.emplace_back(VD_DATABASE_PATH);

        // ...plus the keystore, but only if it is missing (an existing one is kept).
        if (!std::filesystem::exists(VD_KEYSTORE_PATH))
        {
            extractOnly.emplace_back(VD_KEYSTORE_PATH);
        }

        logDebug2(WM_VULNSCAN_LOGTAG, "Starting TAR file decompression.");

        // Decompress TAR file format. m_shouldStop is handed over, presumably so a stop request can
        // interrupt the extraction — it is re-checked below.
        Utils::ArchiveHelper::decompress(DECOMPRESSED_DB_PATH, m_shouldStop, "", extractOnly);
    }
    catch (...)
    {
        // Do not leave the (potentially large) intermediate TAR file behind on failure. The
        // error_code overload is used so the cleanup itself cannot mask the original exception.
        std::error_code cleanupError;
        std::filesystem::remove_all(DECOMPRESSED_DB_PATH, cleanupError);
        throw;
    }

    // Clean up.
    std::filesystem::remove_all(DECOMPRESSED_DB_PATH);

    if (m_shouldStop.load())
    {
        return false;
    }

    logInfo(WM_VULNSCAN_LOGTAG, "Database decompression finished.");
    return true;
}

void VulnerabilityScannerFacade::initAlertReportDispatcher()
{
    const auto alertsMaxEps = PolicyManager::instance().getAlertsMaxEventsPerSecond();
    const auto reportsWait = alertsMaxEps > 0 ? MICROSEC_FACTOR / alertsMaxEps : 0;

    m_reportSocketClient =
        std::make_shared<SocketClient<Socket<OSPrimitives, NoHeaderProtocol>, EpollWrapper>>(DEFAULT_QUEUE_PATH);
    m_reportSocketClient->connect(
        [](const char* data, uint32_t size, const char* dataHeader, uint32_t sizeHeader) {}, []() {}, SOCK_DGRAM);
    m_reportDispatcher = std::make_shared<ReportDispatcher>(
        [this, reportsWait](std::queue<std::string>& dataQueue)
        {
            while (!dataQueue.empty())
            {
                auto& data = dataQueue.front();
                m_reportSocketClient->send(data.c_str(), data.size());
                // We wait to keep the maximum number of events per second
                if (reportsWait > 0)
                {
                    std::this_thread::sleep_for(std::chrono::microseconds(reportsWait));
                }
                logDebug2(WM_VULNSCAN_LOGTAG, "Report sent: %s", data.c_str());
                dataQueue.pop();
            }
        },
        REPORTS_QUEUE_PATH,
        REPORTS_BULK_SIZE);
}

void VulnerabilityScannerFacade::initDelayedEventsDispatcher(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    m_delayedDispatcher = std::make_shared<DelayedEventDispatcher>(
        // coverity[copy_constructor_call]
        [&, scanOrchestrator](std::queue<rocksdb::PinnableSlice>& dataQueue)
        {
            auto& returnSlice = dataQueue.front();

            flatbuffers::Verifier verifier(reinterpret_cast<const uint8_t*>(returnSlice.data()), returnSlice.size());

            if (VerifyMessageBufferBuffer(verifier))
            {
                auto message = GetMessageBuffer(returnSlice.data());

                // If the message is newer than <SLEEP_RETRY_THREADS_SEC> second, wait to retry.
                const auto secondsFromEpoch = getSecondsFromEpoch();
                if ((message->timestamp() + SLEEP_RETRY_THREADS_SEC) > secondsFromEpoch)
                {
                    std::unique_lock<std::mutex> lock(m_retryMutex);
                    if (m_retryWait.wait_for(
                            lock,
                            std::chrono::seconds(message->timestamp() + SLEEP_RETRY_THREADS_SEC - secondsFromEpoch),
                            [&] { return m_shouldStop.load(); }))
                    {
                        logDebug1(WM_VULNSCAN_LOGTAG, "Wait canceled. Event postponed.");
                        return;
                    }
                }

                if (message->type() == BufferType::BufferType_RSync)
                {
                    std::variant<const SyscollectorDeltas::Delta*,
                                 const SyscollectorSynchronization::SyncMsg*,
                                 const nlohmann::json*>
                        data = SyscollectorSynchronization::GetSyncMsg(message->data()->data());

                    try
                    {
                        logDebug1(WM_VULNSCAN_LOGTAG, "Retry rsync event.");
                        scanOrchestrator->run(data);
                    }
                    catch (const std::exception& e)
                    {
                        logWarn(WM_VULNSCAN_LOGTAG, "Discarded event: %s", e.what());
                    }
                }
                else if (message->type() == BufferType::BufferType_DBSync)
                {
                    std::variant<const SyscollectorDeltas::Delta*,
                                 const SyscollectorSynchronization::SyncMsg*,
                                 const nlohmann::json*>
                        data = SyscollectorDeltas::GetDelta(message->data()->data());

                    try
                    {
                        logDebug1(WM_VULNSCAN_LOGTAG, "Retry delta event.");
                        scanOrchestrator->run(data);
                    }
                    catch (const std::exception& e)
                    {
                        logWarn(WM_VULNSCAN_LOGTAG, "Discarded event: %s", e.what());
                    }
                }
                else
                {
                    logWarn(WM_VULNSCAN_LOGTAG, "Discarded unknown event.");
                }
            }
            else
            {
                logWarn(WM_VULNSCAN_LOGTAG, "Discarded unknown event.");
            }

            dataQueue.pop();
        },
        DELAYED_EVENTS_QUEUE_PATH,
        DELAYED_EVENTS_BULK_SIZE);
}

/**
 * @brief Subscribes to syscollector delta events ("deltas-syscollector" router topic).
 *
 * Each message is decoded as a SyscollectorDeltas::Delta flatbuffer and run through the scan
 * orchestrator. If the orchestrator throws OsDataException, the raw message is wrapped in a
 * MessageBuffer (type DBSync plus the current epoch second) and pushed to the delayed-events
 * dispatcher for a later retry; any other std::exception is logged and the event is dropped.
 *
 * @param scanOrchestrator Orchestrator shared with the subscriber callback.
 */
void VulnerabilityScannerFacade::initDeltasSubscription(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    // Subscription to syscollector delta events.
    m_syscollectorDeltasSubscription =
        std::make_unique<RouterSubscriber>("deltas-syscollector", "vulnerability_scanner_deltas");
    m_syscollectorDeltasSubscription->subscribe(
        // coverity[copy_constructor_call]
        [scanOrchestrator, delayedDispatcher = m_delayedDispatcher](const std::vector<char>& message)
        {
            std::variant<const SyscollectorDeltas::Delta*,
                         const SyscollectorSynchronization::SyncMsg*,
                         const nlohmann::json*>
                data = SyscollectorDeltas::GetDelta(message.data());

            try
            {
                scanOrchestrator->run(data);
            }
            catch (const OsDataException& e)
            {
                flatbuffers::FlatBufferBuilder builder;

                // Copy the raw bytes into a genuine std::vector<int8_t>. Reinterpreting the
                // std::vector<char> object as a std::vector<int8_t> (an unrelated class type) and
                // calling members through that pointer would be undefined behavior.
                const std::vector<int8_t> payload(message.begin(), message.end());

                auto object =
                    CreateMessageBufferDirect(builder, &payload, BufferType::BufferType_DBSync, getSecondsFromEpoch());

                builder.Finish(object);
                auto bufferData = reinterpret_cast<const char*>(builder.GetBufferPointer());
                size_t bufferSize = builder.GetSize();
                const rocksdb::Slice messageSlice(bufferData, bufferSize);
                delayedDispatcher->push(messageSlice);
                logDebug1(WM_VULNSCAN_LOGTAG, "%s. Event postponed", e.what());
            }
            catch (const std::exception& e)
            {
                logError(WM_VULNSCAN_LOGTAG, "ScanOrchestrator::run::Exception: %s", e.what());
            }
        });
}
/**
 * @brief Start the rsync events subscription ("rsync-syscollector" router topic).
 *
 * Each message is decoded as a SyscollectorSynchronization::SyncMsg flatbuffer and run through the scan
 * orchestrator. If the orchestrator throws OsDataException, the raw message is wrapped in a
 * MessageBuffer (type RSync plus the current epoch second) and pushed to the delayed-events dispatcher
 * for a later retry; any other std::exception is logged and the event is dropped.
 *
 * @param scanOrchestrator Orchestrator shared with the subscriber callback.
 */
void VulnerabilityScannerFacade::initRsyncSubscription(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    // Subscription to syscollector rsync events.
    m_syscollectorRsyncSubscription =
        std::make_unique<RouterSubscriber>("rsync-syscollector", "vulnerability_scanner_rsync");
    m_syscollectorRsyncSubscription->subscribe(
        // coverity[copy_constructor_call]
        [scanOrchestrator, delayedDispatcher = m_delayedDispatcher](const std::vector<char>& message)
        {
            std::variant<const SyscollectorDeltas::Delta*,
                         const SyscollectorSynchronization::SyncMsg*,
                         const nlohmann::json*>
                data = SyscollectorSynchronization::GetSyncMsg(message.data());

            try
            {
                scanOrchestrator->run(data);
            }
            catch (const OsDataException& e)
            {
                flatbuffers::FlatBufferBuilder builder;

                // Copy the raw bytes into a genuine std::vector<int8_t>. Reinterpreting the
                // std::vector<char> object as a std::vector<int8_t> (an unrelated class type) and
                // calling members through that pointer would be undefined behavior.
                const std::vector<int8_t> payload(message.begin(), message.end());

                auto object =
                    CreateMessageBufferDirect(builder, &payload, BufferType::BufferType_RSync, getSecondsFromEpoch());

                builder.Finish(object);
                auto bufferData = reinterpret_cast<const char*>(builder.GetBufferPointer());
                size_t bufferSize = builder.GetSize();
                const rocksdb::Slice messageSlice(bufferData, bufferSize);
                delayedDispatcher->push(messageSlice);
                logDebug1(WM_VULNSCAN_LOGTAG, "%s. Event postponed.", e.what());
            }
            catch (const std::exception& e)
            {
                logError(WM_VULNSCAN_LOGTAG, "ScanOrchestrator::run::Exception: %s", e.what());
            }
        });
}

/**
 * @brief Subscribes to Wazuh DB agent events ("wdb-agent-events" router topic).
 *
 * Every message is parsed as a JSON document and handed to the scan orchestrator by pointer. Parse
 * errors and orchestrator exceptions are logged and the message is dropped.
 *
 * @param scanOrchestrator Orchestrator that processes the agent events.
 */
void VulnerabilityScannerFacade::initWazuhDBEventSubscription(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    m_wdbAgentEventsSubscription =
        std::make_unique<RouterSubscriber>("wdb-agent-events", "vulnerability_scanner_database");

    m_wdbAgentEventsSubscription->subscribe(
        [scanOrchestrator](const std::vector<char>& message)
        {
            try
            {
                const char* const first = message.data();
                const char* const last = first + message.size();

                // The parsed document must outlive run(): the variant below only stores its address.
                const auto event = nlohmann::json::parse(first, last);

                std::variant<const SyscollectorDeltas::Delta*,
                             const SyscollectorSynchronization::SyncMsg*,
                             const nlohmann::json*>
                    data {&event};

                scanOrchestrator->run(data);
            }
            catch (const std::exception& e)
            {
                logError(WM_VULNSCAN_LOGTAG, "ScanOrchestrator::run::Exception: %s", e.what());
            }
        });
}

void VulnerabilityScannerFacade::vulnerabilityScannerPolicyChange(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    // Check if a rescan is required based on the value of 'm_shouldRescan'.
    if (m_shouldRescan.load())
    {
        logDebug1(WM_VULNSCAN_LOGTAG, "Performing cleanup after reboot.");

        try
        {
            // Create a variant 'data' to store the cleanup action information.
            std::variant<const SyscollectorDeltas::Delta*,
                         const SyscollectorSynchronization::SyncMsg*,
                         const nlohmann::json*>
                data;

            // Create a JSON object 'dataValue' to specify the action as "cleanup."
            nlohmann::json dataValue;
            dataValue["action"] = "cleanup";

            // Assign a reference to 'dataValue' to the 'data' variant.
            data = &dataValue;

            // Execute the scan orchestrator with the cleanup action data.
            scanOrchestrator->run(data);
        }
        catch (const std::exception& e)
        {
            // Log a warning if an exception occurs during cleanup.
            logWarn(WM_VULNSCAN_LOGTAG, "Exception during cleanup: %s", e.what());
        }
    }

    // Create a new thread 'm_rebootThread' using a lambda function.
    m_rebootThread = std::thread(
        [scanOrchestrator, &shouldReboot = m_shouldRescan]()
        {
            // Check if a reboot is required based on the value of 'shouldReboot'.
            if (shouldReboot.load())
            {
                try
                {
                    logDebug1(WM_VULNSCAN_LOGTAG, "Perform re-scan after reboot");

                    // Create a variant 'data' to store the re-scan action information.
                    std::variant<const SyscollectorDeltas::Delta*,
                                 const SyscollectorSynchronization::SyncMsg*,
                                 const nlohmann::json*>
                        data;

                    // Create a JSON object 'dataValue' to specify the action as "reboot."
                    nlohmann::json dataValue;
                    dataValue["action"] = "reboot";

                    // Assign a reference to 'dataValue' to the 'data' variant.
                    data = &dataValue;

                    // Execute the scan orchestrator with the re-scan action data.
                    scanOrchestrator->run(data);
                }
                catch (const std::exception& e)
                {
                    logWarn(WM_VULNSCAN_LOGTAG, "Exception during re-scan: %s", e.what());
                }

                // Reset the 'shouldReboot' flag after the re-scan is performed.
                shouldReboot.store(false);
            }
        });
}

void VulnerabilityScannerFacade::managerScanPolicyChange(std::shared_ptr<ScanOrchestrator> scanOrchestrator)
{
    std::string lastDisableState;
    auto& policyManager = PolicyManager::instance();

    m_stateDB->get("disable_manager_scan", lastDisableState);

    if (!lastDisableState.empty())
    {
        // Check if the last known disable state is "no" and the scanner is now disabled
        if (lastDisableState == "no" &&
            (policyManager.getManagerDisabledScan() == DisableManagerScanStatus::DISABLE_MANAGER_SCAN))
        {
            // Perform manager cleanup
            logInfo(WM_VULNSCAN_LOGTAG, "Vulnerability scanner in manager deactivated. Performing clean-up.");
            m_managerThread = std::thread(
                [scanOrchestrator, stateDB = m_stateDB, &refPolicyManager = policyManager]()
                {
                    try
                    {
                        // Create a JSON object 'dataValue' to specify the action as "deleteAgent."
                        nlohmann::json dataValue;
                        dataValue["action"] = "deleteAgent";
                        dataValue["agent_info"]["agent_id"] = "000";
                        dataValue["agent_info"]["node_name"] = refPolicyManager.getManagerNodeName();

                        // Execute the cleanup action
                        scanOrchestrator->run(&dataValue);
                    }
                    catch (const std::exception& e)
                    {
                        // Log any exceptions that occur during cleanup
                        logError(WM_VULNSCAN_LOGTAG, "Exception during manager clean-up: %s", e.what());
                    }
                });
        }
        // Check if the last known disable state is "yes" and the scanner is now enabled
        else if (lastDisableState == "yes" &&
                 (policyManager.getManagerDisabledScan() == DisableManagerScanStatus::SCAN_MANAGER))
        {
            // Initiate a scan
            logInfo(WM_VULNSCAN_LOGTAG, "Vulnerability scanner in manager activated. Performing scan.");
            m_managerThread = std::thread(
                [scanOrchestrator, stateDB = m_stateDB, &refPolicyManager = policyManager]()
                {
                    try
                    {
                        // Create a JSON object 'dataValue' to specify the action as "scanAgent."
                        nlohmann::json dataValue;
                        dataValue["action"] = "scanAgent";
                        dataValue["agent_info"]["agent_id"] = "000";
                        dataValue["agent_info"]["node_name"] = refPolicyManager.getManagerNodeName();

                        // Execute the scan action
                        scanOrchestrator->run(&dataValue);
                    }
                    catch (const std::exception& e)
                    {
                        // Log any exceptions that occur during scanning
                        logError(WM_VULNSCAN_LOGTAG, "Exception during manager scan: %s", e.what());
                    }
                });
        }
    }

    m_stateDB->put("disable_manager_scan",
                   policyManager.getManagerDisabledScan() == DisableManagerScanStatus::DISABLE_MANAGER_SCAN ? "yes"
                                                                                                            : "no");
}
// TODO: Remove LCOV flags once the implementation of the 'Indexer Connector' module is completed
// LCOV_EXCL_START
/**
 * @brief Initializes and starts the vulnerability scanner module.
 *
 * The steps run in this order:
 * - Configure logging and the policy manager.
 * - Reconcile the persisted "previous_config" flag in the state DB. A disabled-to-enabled transition
 *   requests a re-scan.
 * - Return early if the module is disabled.
 * - Create the indexer connector, when enabled.
 * - Install the bundled feed when its version changed.
 * - Create the feed manager, the report dispatcher and the scan orchestrator.
 * - Apply manager-scan policy changes, then release the state DB handle.
 * - Start the delayed-events dispatcher, the re-scan thread and the router subscriptions.
 *
 * Exceptions are logged, not propagated.
 *
 * @param logFunction Logging callback installed via Log::assignLogFunction (also given to the indexer connector).
 * @param configuration Module configuration forwarded to PolicyManager::initialize.
 * @param noWaitToStop Stored in m_noWaitToStop; stop() uses it to decide whether to raise m_shouldStop.
 * @param reloadGlobalMapsStartup Forwarded to DatabaseFeedManager.
 * @param initContentUpdater Forwarded to DatabaseFeedManager.
 */
void VulnerabilityScannerFacade::start(
    const std::function<void(
        const int, const std::string&, const std::string&, const int, const std::string&, const std::string&, va_list)>&
        logFunction,
    const nlohmann::json& configuration,
    const bool noWaitToStop,
    const bool reloadGlobalMapsStartup,
    const bool initContentUpdater)
{
    try
    {
        m_noWaitToStop = noWaitToStop;

        // Initialize logging
        Log::assignLogFunction(logFunction);

        // Policy manager initialization.
        auto& policyManager = PolicyManager::instance();
        policyManager.initialize(configuration);

        // Nothing below re-initializes the policy, so query the module switch once.
        const bool vulnerabilityDetectionEnabled = policyManager.isVulnerabilityDetectionEnabled();

        // State DB tracking module settings across restarts.
        m_stateDB = std::make_unique<Utils::RocksDBWrapper>(VD_STATE_QUEUE_PATH);

        // Reconcile the persisted enable flag with the current configuration.
        std::string lastState;
        m_stateDB->get("previous_config", lastState);

        if (lastState.empty())
        {
            // First execution of the refactored module: store the current value.
            m_stateDB->put("previous_config", vulnerabilityDetectionEnabled ? "yes" : "no");
        }
        else if (lastState == "no" && vulnerabilityDetectionEnabled)
        {
            // Re-enabled since the previous run: persist it and let vulnerabilityScannerPolicyChange re-scan.
            logInfo(WM_VULNSCAN_LOGTAG, "Vulnerability scanner restarted");
            m_stateDB->put("previous_config", "yes");
            m_shouldRescan.store(true);
        }
        else if (lastState == "yes" && !vulnerabilityDetectionEnabled)
        {
            // Disabled since the previous run.
            m_stateDB->put("previous_config", "no");
        }

        // Return if the module is disabled.
        if (!vulnerabilityDetectionEnabled)
        {
            logInfo(WM_VULNSCAN_LOGTAG, "Vulnerability scanner module is disabled");
            return;
        }

        // Indexer connector initialization.
        if (policyManager.isIndexerEnabled())
        {
            // Fetch the indexer configuration once and reuse it for both constructor arguments.
            const auto& indexerConfig = policyManager.getIndexerConfiguration();
            m_indexerConnector =
                std::make_shared<IndexerConnector>(indexerConfig,
                                                   indexerConfig.contains("template_path")
                                                       ? indexerConfig.at("template_path").get_ref<const std::string&>()
                                                       : VULNERABILITY_SCANNER_TEMPLATE,
                                                   logFunction);
        }

        // Query the current database version.
        std::string databaseVersion;
        m_stateDB->get(VD_DATABASE_VERSION_KEY, databaseVersion);

        // Decompress database content.
        if (decompressDatabase(databaseVersion) && !m_shouldStop.load())
        {
            m_stateDB->put(VD_DATABASE_VERSION_KEY, __ossec_version);

            // Cleanup
            std::filesystem::remove_all(COMPRESSED_DB_PATH);

            logDebug1(WM_VULNSCAN_LOGTAG, "Updated %s key of %s.", VD_DATABASE_VERSION_KEY, VD_STATE_QUEUE_PATH);
        }

        // Database feed manager initialization.
        m_databaseFeedManager = std::make_shared<DatabaseFeedManager>(
            m_indexerConnector, m_shouldStop, m_internalMutex, true, reloadGlobalMapsStartup, initContentUpdater);

        // Socket client initialization to send vulnerability reports.
        initAlertReportDispatcher();

        // Add subscribers for policy updates.
        policyManager.addSubscriber(m_databaseFeedManager);

        // Init Orchestrator
        auto scanOrchestrator = std::make_shared<ScanOrchestrator>(
            m_indexerConnector, m_databaseFeedManager, m_reportDispatcher, m_internalMutex);

        // Policy manager change
        managerScanPolicyChange(scanOrchestrator);

        // Close and reset stateDB
        m_stateDB.reset();

        // Dispatching threads to retry events
        initDelayedEventsDispatcher(scanOrchestrator);

        // Rescan if VD policy change from false to true.
        vulnerabilityScannerPolicyChange(scanOrchestrator);

        // Subscription to syscollector delta events.
        initDeltasSubscription(scanOrchestrator);

        // Subscription to syscollector rsync events.
        initRsyncSubscription(scanOrchestrator);

        // Wazuh DB event subscription.
        initWazuhDBEventSubscription(scanOrchestrator);

        logInfo(WM_VULNSCAN_LOGTAG, "Vulnerability scanner module started");
    }
    catch (const std::exception& e)
    {
        logError(WM_VULNSCAN_LOGTAG, "VulnerabilityScannerFacade::start: %s", e.what());
    }
    catch (...)
    {
        logError(WM_VULNSCAN_LOGTAG, "VulnerabilityScannerFacade::start: Unknown exception");
    }
}
// LCOV_EXCL_STOP

void VulnerabilityScannerFacade::stop()
{
    // Atomic flag section
    if (m_noWaitToStop)
    {
        m_shouldStop.store(true);
    }

    m_retryWait.notify_all();

    // Threads join
    if (m_rebootThread.joinable())
    {
        m_rebootThread.join();
    }

    if (m_managerThread.joinable())
    {
        m_managerThread.join();
    }

    // Reset shared pointers
    m_indexerConnector.reset();
    m_databaseFeedManager.reset();
    m_syscollectorRsyncSubscription.reset();
    m_syscollectorDeltasSubscription.reset();
    m_wdbAgentEventsSubscription.reset();

    // Policy manager teardown
    PolicyManager::instance().teardown();
    m_reportDispatcher.reset();
    m_delayedDispatcher.reset();
    m_stateDB.reset();
}
